function [rules_sep, rules_comb, dens_map, heat_map] = summarize_savings_graph_covariate(covariate,wave,cost,winter_prevuse,SMC,dimension,fbase)
%SUMMARIZE_SAVINGS_GRAPH_COVARIATE Plot EWM treatment rules and grid-search heat maps.
%
% Plots the treatment rules learned using the EWM method and the heat map
% illustrating the solution of the grid-search algorithm.
%
% Inputs:
%   covariate      - "income", "size" or "vintage". Selects the input files
%                    and the x-axis label (other values get no x label).
%   wave           - wave-specific output (3, 6 or 7) or pooled sample (0).
%   cost           - true: welfare measured as cost savings;
%                    false: welfare measured as kWh reduction.
%   winter_prevuse - 1: baseline consumption is the mean of the winter months
%                    (Jan and Feb); 0: mean of the specified pre-treatment
%                    periods. Any other value is an error.
%   SMC            - only used when cost is true. 1: social marginal cost;
%                    0: private marginal cost (retail electricity price).
%   dimension      - "two-dim": plot the rule in the two-dimensional
%                    characteristics space; "one-dim": plot the rule in terms
%                    of pre-treatment consumption only.
%   fbase          - prefix identifying the propensity-score and baseline-
%                    month specification of the input files.
%
% Input files (tag = SMC, PMC or kwh; "_winter" only when winter_prevuse = 1):
%   two-dim: <fbase>_coef_quadrant_<covariate>_<tag>[_winter]_wave<wave>.mat
%            <fbase>_coef_cubic_<covariate>_<tag>[_winter]_wave<wave>.mat
%   one-dim: <fbase>_coef_onedim_<covariate>_baseline_<tag>[_winter]_wave<wave>.mat
%
% Outputs:
%   rules_sep  - two-dim: quadrant and cubic treatment rules side by side.
%                one-dim: the treatment rule over pre-treatment usage.
%   rules_comb - two-dim: quadrant and cubic rules in one plot, with
%                marginal household shares. one-dim: an empty figure.
%   dens_map   - two-dim: grid-search heat map with marginal household
%                shares. one-dim: [].
%   heat_map   - two-dim: per-cell mean of the quadrant file's gu per
%                household (clipped to its 5th-95th percentiles), with both
%                rules overlaid. one-dim: [].

% Outputs a branch does not draw are returned empty rather than left
% unassigned, so callers can always request all four outputs.
dens_map = [];
heat_map = [];

[s_cost, s_winter] = file_suffixes(cost, SMC, winter_prevuse);
x_label = covariate_label(covariate);

switch dimension
    case "two-dim"
        %% Load quadrant (grid-search) and cubic results
        S_quadrant = load(coef_file(fbase, 'quadrant', covariate, s_cost, s_winter, wave));
        S_cubic = load(coef_file(fbase, 'cubic', covariate, s_cost, s_winter, wave));

        % Read fields from the structs rather than eval-ing them into the
        % workspace, which could silently overwrite inputs or locals.
        % A name stored in both files resolves to the cubic file's value,
        % the same precedence the two sequential loads used to have.
        % NOTE(review): shared names such as Xu, nw or v_x1 therefore come
        % from the cubic file if it stores them - confirm that is intended.
        R = S_quadrant;
        cubic_fields = fieldnames(S_cubic);
        for k = 1:numel(cubic_fields)
            R.(cubic_fields{k}) = S_cubic.(cubic_fields{k});
        end

        in_Ghat_quadrant = S_quadrant.in_Ghat;
        gu_quadrant = S_quadrant.gu;
        in_Ghat_cubic = R.in_Ghat;
        beta_cubic = S_cubic.beta;   % not called "beta" to avoid shadowing the built-in
        Xu = R.Xu;
        nw = R.nw;
        n = R.n;
        m_W = R.m_W;
        v_x1 = R.v_x1;
        v_x2 = R.v_x2;
        min_i1 = R.min_i1;
        min_i2 = R.min_i2;
        min_sign1 = R.min_sign1;
        min_sign2 = R.min_sign2;
        x1_wave = R.x1_wave;
        prevuse_wave = R.prevuse_wave;
        Xscale = R.Xscale;

        %% Shared plot inputs
        grid_box = [min(v_x1), max(v_x1), min(v_x2), max(v_x2)];
        N_Row = 6;
        N_Col = 5;
        main_tiles = [(3:N_Col), (3:N_Col)+N_Col, (3:N_Col)+2*N_Col];

        % Weighted household count (and sum of gu) in every grid cell,
        % computed once and reused by all figures below.
        k1 = size(m_W,1);
        k2 = size(m_W,2);
        [N_Person_in_Grid, g_in_Grid] = grid_totals(Xu, v_x1(1:k1), v_x2(1:k2), nw, gu_quadrant);

        [x_coord, y_coord] = quadrant_patch(v_x1, v_x2, min_i1, min_i2, min_sign1, min_sign2);
        [patch_x, patch_y] = cubic_patch(x1_wave, prevuse_wave, Xscale, beta_cubic);

        %% Quadrant and cubic treatment rules, side by side
        rules_sep = figure('position',[100,100,1100,350]); hold on;
        subplot(1,2,1); hold on;
        gscatter(Xu(:,1), Xu(:,2), in_Ghat_quadrant, 'br', '..');
        plot_cutoffs(v_x1, v_x2, min_i1, min_i2, 'k--');
        axis(grid_box);
        title("Optimal quadrant treatment rule")
        ylabel('Pretreatment Usage [kWh/mo]')
        legend('\color{blue} Not treat','\color{red} Treat')
        label_covariate_axis(x_label);

        subplot(1,2,2); hold on;
        gscatter(x1_wave, prevuse_wave, in_Ghat_cubic, 'br', '..');
        title("Optimal cubic treatment rule")
        ylabel('Pretreatment Usage [kWh/mo]')
        legend('\color{blue} Not treat','\color{red} Treat')
        label_covariate_axis(x_label);
        hold off;

        %% Grid-search heat map with marginal household shares
        m_W_perPerson = m_W/n;

        dens_map = figure; hold on;
        set(gcf,'position',[10,10,800,600]);
        colormap gray

        subplot(N_Row,N_Col,main_tiles); hold on;
        imagesc(v_x1,v_x2,m_W_perPerson');
        plot_cutoffs(v_x1, v_x2, min_i1, min_i2, 'w--');
        set(gca,'ydir','normal');
        label_covariate_axis(x_label);
        ylabel({'Pretreatment Usage', '[kWh/mo]'})
        axis(grid_box);
        colorbar();

        plot_y_share(N_Row, N_Col, N_Person_in_Grid, v_x2);
        xlabel('Share of hh')

        plot_x_share(N_Row, N_Col, N_Person_in_Grid, v_x1);
        ylabel('Share of hh')

        % One pass after everything exists covers every text object.
        set(findall(gcf,'-property','FontSize'),'FontSize',22)
        hold off;

        %% Quadrant and cubic rules together
        rules_comb = figure; hold on;
        set(gcf,'position',[10,10,1000,800]);
        set(gcf,'defaultAxesFontSize',26);
        set(gcf,'defaultAxesFontName','Arial');

        subplot(N_Row,N_Col,main_tiles); hold on;
        patch(x_coord, y_coord, 'red', 'LineStyle', '-', 'FaceAlpha', 0.3,'EdgeColor', 'none')
        patch(patch_x,patch_y,'blue', 'LineStyle', '-', 'FaceAlpha', 0.45,'EdgeColor', 'none')
        if covariate=="vintage"
            axis([min(patch_x),max(patch_x),min(v_x2),max(v_x2)]);
        else
            axis(grid_box);
        end
        lgd = legend('Quadrant rule treat','Cubic rule treat');
        set(lgd,'FontSize',16);

        plot_y_share(N_Row, N_Col, N_Person_in_Grid, v_x2);
        xlabel('Share of hh')
        ylabel({'Pretreatment Usage', '[kWh/mo]'})

        plot_x_share(N_Row, N_Col, N_Person_in_Grid, v_x1);
        ylabel('Share of hh')
        label_covariate_axis(x_label);
        hold off;

        %% Heat map of per-household mean gu with both rules overlaid
        mean_g_in_Grid = g_in_Grid./N_Person_in_Grid;
        mean_g_in_Grid(isnan(mean_g_in_Grid)) = 0;   % empty cells give 0/0
        % Clip to the 5th-95th percentile range so a few extreme cells do not
        % dominate the color scale.
        g_lo = prctile(mean_g_in_Grid,5,'all');
        g_hi = prctile(mean_g_in_Grid,95,'all');
        mean_g_in_Grid_clean = min(max(mean_g_in_Grid, g_lo), g_hi);

        heat_map = figure; hold on;
        set(gcf,'position',[10,10,800,600]);
        set(findall(gcf,'-property','FontSize'),'FontSize',20);

        imagesc(v_x1,v_x2,mean_g_in_Grid_clean');
        patch(x_coord, y_coord, 'magenta', 'LineStyle', '-', 'FaceAlpha', 0.3,'EdgeColor', 'magenta')
        patch(patch_x,patch_y,[0.75 0.75 0.75], 'LineStyle', '-', 'FaceAlpha', 0.7,'EdgeColor', [0.75 0.75 0.75])
        plot_cutoffs(v_x1, v_x2, min_i1, min_i2, 'w--');
        if (cost)
            title("Cost savings per household-month($)");
        else
            title("Energy savings per household-month(kWh)");
        end
        set(gca,'ydir','normal');
        label_covariate_axis(x_label);
        ylabel('Pretreatment Usage [kWh/mo]')
        axis(grid_box);
        colorbar();
        hold off;

    case "one-dim"
        S_onedim = load(coef_file(fbase, 'onedim', sprintf('%s_baseline', covariate), s_cost, s_winter, wave));

        Xu = S_onedim.Xu;
        % Column 1 is plotted as pre-treatment usage; column 2 is replaced by
        % a constant so every household sits on a single vertical strip.
        Xu(:,2) = ones(size(Xu,1),1);

        rules_sep = figure; hold on;
        set(gcf,'position',[10,10,800,600]);
        set(findall(gcf,'-property','FontSize'),'FontSize',20);

        gscatter(Xu(:,2),Xu(:,1), S_onedim.in_Ghat, 'kr', '..',[30 30]);
        plot([0.5 1.5],S_onedim.v_x1(S_onedim.min_i1)*[1 1],'k--');
        xlim([0.5 1.5])
        set(gca,'XTick',[])

        % Grey markers at x = 1.2, sized by household weight nw.
        scatter(Xu(:,2)*1.2,Xu(:,1),S_onedim.nw+1,'MarkerEdgeColor','none','MarkerFaceColor',[0.7 .7 0.7]);
        ylabel('Pre-treatment Usage [kWh/mo]')
        legend('','\color{red} Treat','Treatment cutoff','Population density')
        hold off;

        % There is no combined-rule plot in one dimension; an empty figure
        % keeps the output list uniform.
        rules_comb = figure;

    otherwise
        error('summarize_savings_graph_covariate:dimension', ...
            'dimension must be "two-dim" or "one-dim", got "%s".', string(dimension));
end
end

function [s_cost, s_winter] = file_suffixes(cost, SMC, winter_prevuse)
% File-name tags for the welfare measure and the baseline-consumption window.
if cost
    if SMC
        s_cost = 'SMC';
    else
        s_cost = 'PMC';
    end
else
    s_cost = 'kwh';
end

if winter_prevuse == 1
    s_winter = '_winter';
elseif winter_prevuse == 0
    s_winter = '';
else
    error('summarize_savings_graph_covariate:winter_prevuse', ...
        'winter_prevuse must be 0 or 1.');
end
end

function file = coef_file(fbase, kind, tag, s_cost, s_winter, wave)
% Path of a saved result file: <fbase>_coef_<kind>_<tag>_<s_cost><s_winter>_wave<wave>.mat
file = sprintf('%s_coef_%s_%s_%s%s_wave%1.0f.mat', fbase, kind, tag, s_cost, s_winter, wave);
end

function label = covariate_label(covariate)
% Axis label for the covariate; '' when the covariate is not recognized.
switch string(covariate)
    case "income"
        label = 'Income [$ k]';
    case "size"
        label = 'House size [sq ft]';
    case "vintage"
        label = 'House construction year';
    otherwise
        label = '';
end
end

function label_covariate_axis(x_label)
% Label the current x axis, leaving it untouched for unrecognized covariates.
if strlength(x_label) > 0
    xlabel(x_label);
end
end

function [counts, g_sums] = grid_totals(X, edges1, edges2, w, g)
% Sum the weights w (and values g) of the rows of X falling in each grid cell.
% Cell (i1,i2) collects rows with edges1(i1-1) < X(:,1) <= edges1(i1) and
% edges2(i2-1) < X(:,2) <= edges2(i2); the first cell on each axis is open
% below. Rows beyond the last edge or with NaN coordinates are ignored.
% Assumes edges1/edges2 are increasing (grid vectors).
bin1 = discretize(X(:,1), [-Inf; edges1(:)], 'IncludedEdge', 'right');
bin2 = discretize(X(:,2), [-Inf; edges2(:)], 'IncludedEdge', 'right');
keep = ~isnan(bin1) & ~isnan(bin2);
subs = [bin1(keep), bin2(keep)];
sz = [numel(edges1), numel(edges2)];
counts = accumarray(subs, reshape(double(w(keep)), [], 1), sz);
if nargout > 1
    g_sums = accumarray(subs, reshape(double(g(keep)), [], 1), sz);
end
end

function plot_cutoffs(v_x1, v_x2, min_i1, min_i2, style)
% Draw the quadrant rule's two cutoffs as lines spanning the whole grid.
plot([min(v_x1),max(v_x1)],v_x2(min_i2)*[1,1],style)
plot(v_x1(min_i1)*[1,1],[min(v_x2),max(v_x2)],style)
end

function [x_coord, y_coord] = quadrant_patch(v_x1, v_x2, min_i1, min_i2, min_sign1, min_sign2)
% Corners of the rectangle the quadrant rule treats: on each axis, the side
% of the cutoff picked by the sign of min_sign1 / min_sign2.
x_span = cutoff_side(v_x1, min_i1, min_sign1);
y_span = cutoff_side(v_x2, min_i2, min_sign2);
x_coord = x_span([1 1 2 2]);
y_coord = y_span([1 2 2 1]);
end

function span = cutoff_side(v, idx, sgn)
% [low, high] bounds of the grid on the side of v(idx) selected by sgn.
if sgn < 0
    span = [min(v), v(idx)];
elseif sgn > 0
    span = [v(idx), max(v)];
else
    error('summarize_savings_graph_covariate:sign', ...
        'Quadrant rule sign must be nonzero.');
end
end

function [patch_x, patch_y] = cubic_patch(x1_wave, prevuse_wave, Xscale, b)
% Polygon of the region the cubic rule treats, evaluated on unit steps of x1.
line_x1 = (min(x1_wave):max(x1_wave))';
% Usage cutoff at each x1: solves
%   b1 + b2*x/s1 + b3*x^2/s2 + b4*x^3/s3 + b5*u/s4 = 0   for u,
% with s1..s4 = Xscale(1,1:4).
line_usage = Xscale(1,4)*...
    (-(b(1)+line_x1*b(2)./Xscale(1,1)...
    +(line_x1.^2)*b(3)./Xscale(1,2)...
    +(line_x1.^3)*b(4)./Xscale(1,3))./b(5));

if b(5) > 0   % treatment rule is increasing in pre-treatment consumption
    select = (line_usage <= max(prevuse_wave));
    patch_x = [line_x1(select); flipud(line_x1(select))];
    patch_y = [max(line_usage(select),min(prevuse_wave)); max(prevuse_wave)*ones(sum(select),1)];
else          % treatment rule is decreasing in pre-treatment consumption
    select = (line_usage >= min(prevuse_wave));
    patch_x = [line_x1(select); flipud(line_x1(select))];
    patch_y = [min(line_usage(select),max(prevuse_wave)); min(prevuse_wave)*ones(sum(select),1)];
end
end

function plot_y_share(N_Row, N_Col, counts, v_x2)
% Left margin: share of households per usage bin, drawn sideways.
subplot(N_Row,N_Col,[1, 1+N_Col, 1+2*N_Col])
plot(sum(counts,1)./sum(counts,'all'),v_x2,'k','LineWidth',1.5);
set(gca,'color','none','ytick',[],'xdir','reverse','YAxisLocation','origin');
ylim([min(v_x2),max(v_x2)]); box off;
end

function plot_x_share(N_Row, N_Col, counts, v_x1)
% Bottom margin: share of households per covariate bin.
subplot(N_Row,N_Col,(3:N_Col)+N_Col*4)
plot(v_x1,sum(counts,2)/sum(counts,'all'),'k','LineWidth',1.5);
xlim([min(v_x1),max(v_x1)]); box off;
set(gca,'color','none','xtick',[]);
end